{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Embedding Model for Satellite Images with Torchvision's ResNet\n",
    "\n",
    "This tutorial demonstrates how to train an embedding model with LightlyTrain on\n",
    "unlabeled data. The model is then used to generate embeddings from the images and\n",
    "visualize them in 2D. Embedding models are useful for a variety of tasks such as:\n",
    "\n",
    "- Image retrieval\n",
    "- Clustering\n",
    "- Outlier detection\n",
    "- Dataset curation\n",
    "\n",
    "For this tutorial we will use the [Aerial Image Dataset (AID)](https://captain-whu.github.io/AID/)\n",
    "which contains 30,000 satellite images from Google Earth grouped into 30 classes.\n",
    "\n",
    "![](https://captain-whu.github.io/AID/aid-dataset.png)\n",
    "Example images from the AID dataset [[source](https://captain-whu.github.io/AID/)].\n",
    "\n",
    "## Install Dependencies\n",
    "\n",
    "To get started, we first need to install the required dependencies:\n",
    "\n",
    "- `lightly-train` to train the embedding model and generate the embeddings\n",
    "- `umap-learn` to reduce the dimensionality of the embeddings for visualization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly-train umap-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "## Download Dataset\n",
    "\n",
    "Next, we have to download the [AID dataset](https://captain-whu.github.io/AID/):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://www.kaggle.com/api/v1/datasets/download/jiayuanchengala/aid-scene-classification-datasets\n",
    "!unzip aid-scene-classification-datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "After unzipping, the dataset looks like this:\n",
    "\n",
    "```bash\n",
    "AID\n",
    "├── Airport\n",
    "│   ├── airport_100.jpg\n",
    "│   ├── ...\n",
    "│   └── airport_9.jpg\n",
    "├── BareLand\n",
    "│   ├── bareland_100.jpg\n",
    "│   ├── ...\n",
    "│   └── bareland_9.jpg\n",
    "├── ...\n",
    "└── Viaduct\n",
    "    ├── viaduct_100.jpg\n",
    "    ├── ...\n",
    "    └── viaduct_9.jpg\n",
    "```\n",
    "\n",
    "The images are grouped by class into subdirectories. LightlyTrain doesn't need the\n",
    "class information for training, but we will use it later to check the quality of the\n",
    "learned embeddings.\n",
    "\n",
    "## Train the Embedding Model\n",
    "\n",
    "Once the data is downloaded, we can start training the embedding model. We will use\n",
    "a lightweight ResNet18 model from torchvision for this. We also use bf16-mixed precision\n",
    "to speed up training. If your GPU does not support mixed precision, you can remove the\n",
    "`precision` argument.\n",
    "\n",
    "Training for 1000 epochs on a single RTX 4090 GPU takes about 5 hours. If you don't want\n",
    "to wait that long, you can reduce the number of epochs to 100. This will result in lower\n",
    "embedding quality, but only takes 30 minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightly_train\n",
    "\n",
    "lightly_train.train(\n",
    "    out=\"out/aid_resnet18_lightly_train\",\n",
    "    data=\"AID\",\n",
    "    model=\"torchvision/resnet18\",\n",
    "    epochs=1000,\n",
    "    precision=\"bf16-mixed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Embed the Images\n",
    "\n",
    "Once the model is trained, we can use it to generate embeddings for the images. We will\n",
    "save the embeddings to a file called `embeddings_lightly_train.pt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightly_train\n",
    "\n",
    "lightly_train.embed(\n",
    "    out=\"embeddings_lightly_train.pt\",\n",
    "    data=\"AID\",\n",
    "    checkpoint=\"out/aid_resnet18_lightly_train/checkpoints/last.ckpt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## Visualize the Embeddings\n",
    "\n",
    "Now that we have the embeddings, we can visualize them in 2D with [UMAP](https://umap-learn.readthedocs.io/en/latest/).\n",
    "UMAP is a dimension reduction technique that is well suited for visualizing\n",
    "high-dimensional data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import umap\n",
    "\n",
    "# Load the embeddings\n",
    "data = torch.load(\"embeddings_lightly_train.pt\", weights_only=True, map_location=\"cpu\")\n",
    "embeddings = data[\"embeddings\"]\n",
    "filenames = data[\"filenames\"]\n",
    "\n",
    "# Reduce dimensions with UMAP\n",
    "reducer = umap.UMAP()\n",
    "embedding_2d = reducer.fit_transform(embeddings)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 8))\n",
    "plt.scatter(embedding_2d[:, 0], embedding_2d[:, 1], s=5)\n",
    "plt.title(\"UMAP of LightlyTrain Embeddings\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "We can see that the embeddings are nicely separated into well-defined clusters. Such\n",
    "visualizations are extremely useful when curating a dataset. They can quickly give you\n",
    "an overview of your data including outliers and duplicates. Furthermore, the clusters\n",
    "can be used to efficiently label your dataset.\n",
    "\n",
    "## Color the Clusters\n",
    "\n",
    "Let's check if the clusters make sense by coloring them according to the class labels\n",
    "that are available in this dataset. All filenames have the format `<class>/<image_name>.jpg`\n",
    "which lets us extract the class labels easily. Let's plot the embeddings again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Color embeddings based on class labels\n",
    "class_name_to_id = {\n",
    "    class_name: i\n",
    "    for i, class_name in enumerate({filename.split(\"/\")[0] for filename in filenames})\n",
    "}\n",
    "filename_to_class_id = {\n",
    "    filename: class_name_to_id[filename.split(\"/\")[0]] for filename in filenames\n",
    "}\n",
    "color = [filename_to_class_id[filename] for filename in filenames]\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 8))\n",
    "plt.scatter(embedding_2d[:, 0], embedding_2d[:, 1], s=5, c=color, cmap=\"tab20\")\n",
    "plt.title(\"UMAP of LightlyTrain Embeddings Colored by Class\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "The embeddings are well separated by class with few outliers. The LightlyTrain model\n",
    "has learned meaningful embeddings **without** using any class information! For\n",
    "reference, we show a comparison to embeddings generated with an ImageNet supervised\n",
    "pretrained model below:\n",
    "\n",
    "![](https://raw.githubusercontent.com/lightly-ai/lightly-train/refs/heads/main/docs/source/_static/images/tutorials/embedding/umap_lightly_train_imagenet_colored.jpg)\n",
    "\n",
    "We can see that the clusters from the LightlyTrain embeddings are much more compact\n",
    "and have fewer overlaps. This means that the model has learned better representations\n",
    "and will make fewer mistakes for embedding-based tasks like image retrieval or\n",
    "clustering. This highlights how training an embedding model on the target dataset can\n",
    "improve the embeddings quality compared to using an off-the-shelf embedding model.\n",
    "\n",
    "## Conclusion\n",
    "\n",
    "In this tutorial we have learned how to train an embedding model using unlabeled data\n",
    "with LightlyTrain. We have also seen how to visualize the embeddings with UMAP and\n",
    "color them according to class labels. The visualizations show that the model has learned\n",
    "strong embeddings that capture the information of the images well and group similar\n",
    "images together. This is a great starting point for fine-tuning or any embedding-based\n",
    "task such as image retrieval, clustering, outlier detection or dataset curation.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
